import torch
import torch.nn as nn
import torch.backends.cudnn
import wandb

import os
import json 

import train
import val
import test

import models.get_model
import datasets.cifar_loader
import utils.train_utils
import gpytorch
from utils.seed_utils import set_seed
import utils.utils

import warmup_scheduler

# Log in to Weights & Biases as soon as this module is loaded, i.e. before any
# of the entry points below call wandb.init().
wandb.login()
def main(args):
    """Train a classifier on CIFAR and checkpoint the best model per validation metric.

    For each of ``args.nb_run`` repetitions a fresh model is built, optimised
    with Adam under a gradual warm-up followed by cosine annealing, validated
    after every epoch, and saved under the run directory whenever validation
    accuracy, AUROC or AURC improves.

    Args:
        args: Parsed command-line namespace. Reads, among others, ``attn_type``,
            ``save_dir``, ``dataset``, ``model``, ``seed``, ``gpu``,
            ``nb_run``, ``nb_epochs`` and the optimiser/scheduler settings.

    Raises:
        ValueError: If ``args.attn_type`` is not ``'softmax'``, ``'kep_svgp'``
            or ``'sgpa'``.
    """
    # Each supported attention type has its own directory name and wandb group.
    if args.attn_type == 'softmax':
        run_dir = f"{args.dataset}_{args.attn_type}_{args.model}_{args.seed}"
        group = "VIT-CIFAR"
    elif args.attn_type == 'kep_svgp':
        run_dir = (
            f"{args.dataset}_{args.attn_type}_{args.model}_ksvdlayer{args.ksvd_layers}"
            f"_ksvd{args.eta_ksvd}_kl{args.eta_kl}_{args.seed}"
        )
        group = "KEP-SVGP-CIFAR"
    elif args.attn_type == 'sgpa':
        run_dir = f"{args.dataset}_{args.attn_type}_{args.model}_{args.seed}"
        group = "SGPA-CIFAR"
    else:
        # Fail early with a clear message instead of an UnboundLocalError later on.
        raise ValueError(
            f"Unsupported attn_type {args.attn_type!r}; "
            "expected 'softmax', 'kep_svgp' or 'sgpa'."
        )

    save_path = os.path.join(args.save_dir, run_dir)
    os.makedirs(save_path, exist_ok=True)

    wandb.init(project='Difformer',
               group=group,
               name=f"Seed_{args.seed}",
               config=vars(args))

    # The device mask only takes effect if it is set before CUDA is initialised,
    # so export it ahead of seeding and model creation.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Set seed everything
    set_seed(args.seed)

    logger = utils.utils.get_logger(save_path)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    train_loader, val_loader, _, nb_cls = datasets.cifar_loader.get_loader(
        args.dataset, args.train_dir, args.val_dir, args.test_dir, args.batch_size
    )

    for run in range(args.nb_run):
        prefix = f'{run + 1} / {args.nb_run} Running'
        logger.info(100*'#' + '\n' + prefix)

        ## define model
        net = models.get_model.get_model(args.model, nb_cls, logger, args)
        # Number of trainable parameters.
        print(sum(p.numel() for p in net.parameters() if p.requires_grad))
        net.cuda()

        ## define optimizer with warm-up
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay
        )
        base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.nb_epochs, eta_min=args.min_lr
        )
        scheduler = warmup_scheduler.GradualWarmupScheduler(
            optimizer,
            multiplier=1.,
            total_epoch=args.warmup_epoch,
            after_scheduler=base_scheduler
        )

        ## best validation metrics seen so far in this repetition
        best_acc, best_auroc, best_aurc = 0, 0, 1e6

        ## start training
        for epoch in range(args.nb_epochs):
            train.train(train_loader, net, optimizer, epoch, logger, args)

            scheduler.step()

            # validation
            res = val.validation(val_loader, net, args)
            log = [f"{key}: {res[key]:.3f}" for key in res]
            msg = '################## \n ---> Validation Epoch {:d}\t'.format(epoch) + '\t'.join(log)
            logger.info(msg)

            # wandb drops entries whose step is lower than one already logged,
            # so steps keep increasing across repetitions instead of restarting at 0.
            global_step = run * args.nb_epochs + epoch
            wandb.log({f"Val/{key}": res[key] for key in res}, step=global_step)

            if res['Acc.'] > best_acc:
                acc = res['Acc.']
                logger.info(f'Accuracy improved from {best_acc:.2f} to {acc:.2f}!!!')
                best_acc = acc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_acc_net_{run+1}.pth'))

            if res['AUROC'] > best_auroc:
                auroc = res['AUROC']
                logger.info(f'AUROC improved from {best_auroc:.2f} to {auroc:.2f}!!!')
                best_auroc = auroc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_auroc_net_{run+1}.pth'))

            # AURC is tracked as lower-is-better.
            if res['AURC'] < best_aurc:
                aurc = res['AURC']
                logger.info(f'AURC decreased from {best_aurc:.2f} to {aurc:.2f}!!!')
                best_aurc = aurc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_aurc_net_{run+1}.pth'))
                

def main_svdkl(args):
    """Train a GP layer and softmax likelihood on top of a frozen pretrained feature extractor.

    For each of ``args.nb_run`` repetitions the feature extractor is loaded
    from the pretrained ``best_acc_net_{run}.pth`` checkpoint and frozen, then
    the model and a ``SoftmaxLikelihood`` are optimised with SGD against a
    variational ELBO. The model and likelihood are saved whenever validation
    accuracy improves; AUROC and AURC improvements are only logged.

    Args:
        args: Parsed command-line namespace. Reads, among others, ``attn_type``,
            ``save_dir``, ``pretrained_dir``, ``pretrained_seed``, ``dataset``,
            ``model``, ``seed``, ``gpu``, ``hdim``, ``nb_run`` and
            ``nb_epochs``. ``args.lr`` is overwritten with ``0.1``.

    Raises:
        ValueError: If ``args.attn_type`` is not ``'softmax'``.
    """
    if args.attn_type != 'softmax':
        # Only the softmax configuration defines save/pretrained paths; fail
        # early instead of hitting an UnboundLocalError further down.
        raise ValueError(
            f"Unsupported attn_type {args.attn_type!r} for SVDKL; expected 'softmax'."
        )

    save_path = os.path.join(args.save_dir, f"{args.dataset}_{args.attn_type}_{args.model}_{args.seed}")
    pretrained_path = os.path.join(
        args.pretrained_dir,
        f"{args.dataset}_{args.attn_type}_vit_cifar_{args.pretrained_seed}"
    )
    group = "SVDKL-VIT-CIFAR"

    os.makedirs(save_path, exist_ok=True)

    wandb.init(project='Difformer',
               group=group,
               name=f"Seed_{args.seed}_svdkl",
               config=vars(args))

    # The device mask only takes effect if it is set before CUDA is initialised,
    # so export it ahead of seeding and model creation.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Set seed everything
    set_seed(args.seed)

    logger = utils.utils.get_logger(save_path)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    train_loader, val_loader, _, nb_cls = datasets.cifar_loader.get_loader(
        args.dataset, args.train_dir, args.val_dir, args.test_dir, args.batch_size
    )

    for run in range(args.nb_run):
        prefix = f'{run + 1} / {args.nb_run} Running'
        logger.info(100*'#' + '\n' + prefix)

        ## define model
        net = models.get_model.get_model(args.model, nb_cls, logger, args)
        net.feature_extractor.load_state_dict(
            torch.load(os.path.join(pretrained_path, f'best_acc_net_{run + 1}.pth'))
        )
        # Freeze the pretrained feature extractor.
        for params in net.feature_extractor.parameters():
            params.requires_grad = False
        print(net)
        print(sum(p.numel() for p in net.parameters() if p.requires_grad))
        net.cuda()
        likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=args.hdim, num_classes=nb_cls).cuda()

        ## define optimizer
        # NOTE: this overrides whatever learning rate was passed on the command line.
        args.lr = 0.1
        optimizer = torch.optim.SGD([
            {'params': net.feature_extractor.parameters(), 'weight_decay': 1e-4},
            {'params': net.gp_layer.hyperparameters(), 'lr': args.lr * 0.01},
            {'params': net.gp_layer.variational_parameters()},
            {'params': likelihood.parameters()},
        ], lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0)
        # Milestones are compared against integer epoch counters, so a fractional
        # value such as 0.75 * 50 = 37.5 would never trigger; truncate to int.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(0.5 * args.nb_epochs), int(0.75 * args.nb_epochs)],
            gamma=0.1
        )
        mll = gpytorch.mlls.VariationalELBO(likelihood, net.gp_layer, num_data=len(train_loader.dataset))

        ## best validation metrics seen so far in this repetition
        best_acc, best_auroc, best_aurc = 0, 0, 1e6

        ## start training
        for epoch in range(args.nb_epochs):
            # wandb drops entries whose step is lower than one already logged,
            # so steps keep increasing across repetitions instead of restarting at 0.
            global_step = run * args.nb_epochs + epoch

            # Window meters feed the console log and are reset every 100 batches;
            # epoch meters are never reset so wandb receives a full-epoch average.
            window_log = {
                'Tot. Loss': utils.utils.AverageMeter(),
                'LR': utils.utils.AverageMeter(),
            }
            epoch_log = {key: utils.utils.AverageMeter() for key in window_log}

            logger.info('####### --- Training Epoch {:d} --- #######'.format(epoch))
            with gpytorch.settings.use_toeplitz(False):
                net.train()
                likelihood.train()

                for i, (data, target) in enumerate(train_loader):
                    data, target = data.cuda(), target.cuda()
                    optimizer.zero_grad()
                    output = net(data)
                    # The ELBO is maximised, so minimise its negative.
                    loss = -mll(output, target)
                    loss.backward()
                    optimizer.step()

                    # Report the learning rate of the first parameter group.
                    lr = optimizer.param_groups[0]["lr"]
                    loss_value = loss.item()
                    batch_size = data.size(0)
                    for meters in (window_log, epoch_log):
                        meters['Tot. Loss'].update(loss_value, batch_size)
                        meters['LR'].update(lr, batch_size)

                    if i % 100 == 99:
                        log = ['LR : {:.5f}'.format(window_log['LR'].avg)] + [
                            key + ': {:.2f}'.format(window_log[key].avg) for key in window_log if key != 'LR'
                        ]
                        msg = 'Epoch {:d} \t Batch {:d}\t'.format(epoch, i) + '\t'.join(log)
                        logger.info(msg)
                        for key in window_log:
                            window_log[key] = utils.utils.AverageMeter()

                wandb.log({f"Train/{key}": epoch_log[key].avg for key in epoch_log}, step=global_step)

            scheduler.step()

            # validation
            res = val.validation(val_loader, (net, likelihood), args, method='svdkl')
            log = [f"{key}: {res[key]:.3f}" for key in res]
            msg = '################## \n ---> Validation Epoch {:d}\t'.format(epoch) + '\t'.join(log)
            logger.info(msg)

            wandb.log({f"Val/{key}": res[key] for key in res}, step=global_step)

            if res['Acc.'] > best_acc:
                acc = res['Acc.']
                logger.info(f'Accuracy improved from {best_acc:.2f} to {acc:.2f}!!!')
                best_acc = acc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_acc_net_{run+1}.pth'))
                torch.save(likelihood.state_dict(), os.path.join(save_path, f'best_acc_likelihood_{run+1}.pth'))

            # AUROC / AURC improvements are logged only; no checkpoint is written for them.
            if res['AUROC'] > best_auroc:
                auroc = res['AUROC']
                logger.info(f'AUROC improved from {best_auroc:.2f} to {auroc:.2f}!!!')
                best_auroc = auroc

            if res['AURC'] < best_aurc:
                aurc = res['AURC']
                logger.info(f'AURC decreased from {best_aurc:.2f} to {aurc:.2f}!!!')
                best_aurc = aurc
      
def main_diffusion(args):
    """Train the diffusion model initialised from, and guided by, a pretrained ViT.

    For each of ``args.nb_run`` repetitions a fresh model is built, its
    embedding, positional embedding, final layer norm and both solution heads
    are initialised from the pretrained ``q_distribution`` checkpoint, and it is
    trained with Adam under a gradual warm-up followed by cosine annealing.
    Checkpoints are saved whenever validation accuracy, AUROC or AURC improves,
    plus a final ``last_net`` checkpoint per repetition.

    Args:
        args: Parsed command-line namespace. Reads, among others, ``attn_type``,
            ``backbone``, ``save_dir``, ``pretrained_dir``, ``pretrained_seed``,
            ``dataset``, ``model``, ``seed``, ``gpu``, ``depth``, ``nb_run``,
            ``nb_epochs`` and the backbone-specific hyper-parameters that are
            encoded in the save directory name.

    Raises:
        ValueError: If ``args.attn_type`` is not ``'softmax'`` or
            ``'kep_svgp'``, or ``args.backbone`` is not ``'mlp'``, ``'lstm'``,
            ``'gru'`` or ``'transformer'``.
    """
    # Attention-specific part of the directory names and the wandb group.
    if args.attn_type == 'softmax':
        ksvd_tag = ""
        group = "VIT-DiT"
    elif args.attn_type == 'kep_svgp':
        ksvd_tag = f"_ksvdlayer{args.ksvd_layers}_ksvd{args.eta_ksvd}_kl{args.eta_kl}"
        group = "KEP-SVGP-DiT"
    else:
        raise ValueError(
            f"Unsupported attn_type {args.attn_type!r}; expected 'softmax' or 'kep_svgp'."
        )

    # Backbone-specific hyper-parameter suffix of the save directory.
    if args.backbone == 'mlp':
        # Only the kep_svgp naming scheme includes rnn_low_dim for MLP backbones;
        # kept as-is so existing checkpoint directories still resolve.
        low_dim_tag = f"_{args.rnn_low_dim}" if args.attn_type == 'kep_svgp' else ""
        backbone_tag = (
            f"{args.mlp_hdim1}_{args.mlp_hdim2}_{args.mlp_hdim3}_{args.mlp_dropout}"
            f"{low_dim_tag}_{args.lr}_{args.clip}_{args.nb_epochs}"
        )
    elif args.backbone in ('lstm', 'gru'):
        backbone_tag = (
            f"{args.rnn_hidden}_{args.rnn_num_layers}_{args.rnn_dropout}"
            f"_{args.rnn_low_dim}_{args.lr}_{args.nb_epochs}"
        )
    elif args.backbone == 'transformer':
        backbone_tag = (
            f"{args.trans_depth}_{args.trans_num_heads}_{args.trans_mlp_ratio}"
            f"_{args.trans_dropout}_{args.lr}_{args.nb_epochs}"
        )
    else:
        raise ValueError(
            f"Unsupported backbone {args.backbone!r}; "
            "expected 'mlp', 'lstm', 'gru' or 'transformer'."
        )

    save_path = os.path.join(
        args.save_dir,
        f"{args.dataset}_{args.attn_type}_{args.model}{ksvd_tag}_{args.seed}_{args.backbone}_{backbone_tag}",
    )
    pretrained_path = os.path.join(
        args.pretrained_dir,
        f"{args.dataset}_{args.attn_type}_vit_cifar{ksvd_tag}_{args.pretrained_seed}",
    )

    os.makedirs(save_path, exist_ok=True)

    wandb.init(project='Difformer',
               group=group,
               name=(
                   f"Diffusion {args.run_name}: seed_{args.seed}_lr_{args.lr}"
                   f"_pretrained_seed_{args.pretrained_seed}_ksvd_layers_{args.ksvd_layers}"
                   f"_lambda_mean_{args.lambda_mean}_var_{args.lambda_var}_ce_{args.lambda_ce}"
                   f"_batchsize_{args.batch_size}_epochs_{args.nb_epochs}"
               ),
               config=vars(args))

    # The device mask only takes effect if it is set before CUDA is initialised,
    # so export it ahead of seeding and model creation.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Set seed everything
    set_seed(args.seed)

    logger = utils.utils.get_logger(save_path)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    train_loader, val_loader, _, nb_cls = datasets.cifar_loader.get_loader(
        args.dataset, args.train_dir, args.val_dir, args.test_dir, args.batch_size
    )

    for run in range(args.nb_run):
        prefix = f'{run + 1} / {args.nb_run} Running'
        logger.info(100*'#' + '\n' + prefix)

        ## define model
        net = models.get_model.get_model(args.model, nb_cls, logger, args)
        print(net)
        print(sum(p.numel() for p in net.parameters() if p.requires_grad))
        net.cuda()

        ## load the pretrained ViT and copy the shared components into the new model
        pretrained_ViT = models.get_model.get_model('q_distribution', nb_cls, logger, args)
        pretrained_ViT.load_state_dict(torch.load(os.path.join(pretrained_path, f'best_acc_net_{run + 1}.pth')))
        pretrained_ViT.cuda()
        last_block = pretrained_ViT.enc[args.depth - 1]
        net.emb.load_state_dict(pretrained_ViT.emb.state_dict())
        net.pos_emb.data.copy_(pretrained_ViT.pos_emb.data)
        net.ln.load_state_dict(last_block.la2.state_dict())
        net.solution_head_1.load_state_dict(last_block.mlp.state_dict())
        net.solution_head_2.load_state_dict(pretrained_ViT.fc.state_dict())

        ## define optimizer with warm-up
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay
        )
        base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.nb_epochs, eta_min=args.min_lr
        )
        scheduler = warmup_scheduler.GradualWarmupScheduler(
            optimizer,
            multiplier=1.,
            total_epoch=args.warmup_epoch,
            after_scheduler=base_scheduler
        )

        ## best validation metrics seen so far in this repetition
        best_acc, best_auroc, best_aurc = 0, 0, 1e6

        ## start training
        for epoch in range(args.nb_epochs):
            train.train_diffusion(train_loader, net, optimizer, epoch, logger, args, pretrained_ViT)

            scheduler.step()

            # validation
            res = val.validation_diffusion(val_loader, net, args, pretrained_ViT)
            log = [f"{key}: {res[key]:.3f}" for key in res]
            msg = '################## \n ---> Validation Epoch {:d}\t'.format(epoch) + '\t'.join(log)
            logger.info(msg)

            # wandb drops entries whose step is lower than one already logged,
            # so steps keep increasing across repetitions instead of restarting at 0.
            global_step = run * args.nb_epochs + epoch
            wandb.log({f"Val/{key}": res[key] for key in res}, step=global_step)

            if res['Acc.'] > best_acc:
                acc = res['Acc.']
                logger.info(f'Accuracy improved from {best_acc:.2f} to {acc:.2f}!!!')
                best_acc = acc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_acc_net_{run+1}_diffusion_{args.backbone}.pth'))

            if res['AUROC'] > best_auroc:
                auroc = res['AUROC']
                logger.info(f'AUROC improved from {best_auroc:.2f} to {auroc:.2f}!!!')
                best_auroc = auroc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_auroc_net_{run+1}_diffusion_{args.backbone}.pth'))

            # AURC is tracked as lower-is-better.
            if res['AURC'] < best_aurc:
                aurc = res['AURC']
                logger.info(f'AURC decreased from {best_aurc:.2f} to {aurc:.2f}!!!')
                best_aurc = aurc
                torch.save(net.state_dict(), os.path.join(save_path, f'best_aurc_net_{run+1}_diffusion_{args.backbone}.pth'))

        # Always keep the final weights of the repetition as well.
        torch.save(net.state_dict(), os.path.join(save_path, f'last_net_{run+1}_diffusion_{args.backbone}.pth'))

def main_distillation(args):
    """Train a student model by distillation from a pretrained ViT.

    ``args.model`` selects the student/pretrained pair:

    * ``'diffusion_distillation'``: a ``'diffusion'`` student and a
      ``'vit_cifar'`` pretrained model, whose embedding, positional embedding,
      final layer norm and heads are copied into the student.
    * ``'vit_cifar_distillation'``: a ``'vit_cifar'`` student and a
      ``'vit_cifar_teacher'`` pretrained model, with no weight copy.

    For each of ``args.nb_run`` repetitions the student is trained with Adam
    under a gradual warm-up followed by cosine annealing. A checkpoint is saved
    whenever validation accuracy improves (AUROC / AURC improvements are only
    logged), plus a final ``last_net`` checkpoint per repetition.

    Args:
        args: Parsed command-line namespace. Reads, among others, ``model``,
            ``attn_type``, ``backbone``, ``save_dir``, ``pretrained_dir``,
            ``pretrained_seed``, ``dataset``, ``seed``, ``gpu``, ``depth``,
            ``temperature``, ``lambda_mean``, ``lambda_var``, ``lambda_ce``,
            ``nb_run`` and ``nb_epochs``.

    Raises:
        ValueError: If ``args.model``, ``args.attn_type`` or ``args.backbone``
            is not one of the supported values.
    """
    # Student / pretrained model names for the two supported distillation setups.
    if args.model == 'diffusion_distillation':
        student_name, teacher_name = 'diffusion', 'vit_cifar'
    elif args.model == 'vit_cifar_distillation':
        student_name, teacher_name = 'vit_cifar', 'vit_cifar_teacher'
    else:
        raise ValueError(
            f"Unsupported model {args.model!r}; "
            "expected 'diffusion_distillation' or 'vit_cifar_distillation'."
        )

    # Attention-specific part of the directory names and the wandb group.
    if args.attn_type == 'softmax':
        ksvd_tag = ""
        group = "VIT-DiT-Distillation"
    elif args.attn_type == 'kep_svgp':
        ksvd_tag = f"_ksvdlayer{args.ksvd_layers}_ksvd{args.eta_ksvd}_kl{args.eta_kl}"
        group = "KEP-SVGP-DiT-Distillation"
    else:
        raise ValueError(
            f"Unsupported attn_type {args.attn_type!r}; expected 'softmax' or 'kep_svgp'."
        )

    # Backbone-specific hyper-parameter suffix of the save directory.
    if args.backbone == 'mlp':
        # Only the kep_svgp naming scheme includes rnn_low_dim for MLP backbones;
        # kept as-is so existing checkpoint directories still resolve.
        low_dim_tag = f"_{args.rnn_low_dim}" if args.attn_type == 'kep_svgp' else ""
        backbone_tag = (
            f"{args.mlp_hdim1}_{args.mlp_hdim2}_{args.mlp_hdim3}_{args.mlp_dropout}"
            f"{low_dim_tag}_{args.lr}_{args.clip}_{args.nb_epochs}"
        )
    elif args.backbone in ('lstm', 'gru'):
        backbone_tag = (
            f"{args.rnn_hidden}_{args.rnn_num_layers}_{args.rnn_dropout}"
            f"_{args.rnn_low_dim}_{args.lr}_{args.nb_epochs}"
        )
    elif args.backbone == 'transformer':
        backbone_tag = (
            f"{args.trans_depth}_{args.trans_num_heads}_{args.trans_mlp_ratio}"
            f"_{args.trans_dropout}_{args.lr}_{args.nb_epochs}"
        )
    else:
        raise ValueError(
            f"Unsupported backbone {args.backbone!r}; "
            "expected 'mlp', 'lstm', 'gru' or 'transformer'."
        )

    save_path = os.path.join(
        args.save_dir,
        f"{args.dataset}_{args.attn_type}_{args.model}{ksvd_tag}_{args.seed}_{args.backbone}_{backbone_tag}",
    )
    pretrained_path = os.path.join(
        args.pretrained_dir,
        f"{args.dataset}_{args.attn_type}_vit_cifar{ksvd_tag}_{args.pretrained_seed}",
    )

    os.makedirs(save_path, exist_ok=True)

    wandb.init(project='Difformer',
               group=group,
               name=(
                   f"Diffusion {args.run_name}: seed_{args.seed}_lr_{args.lr}"
                   f"_pretrained_seed_{args.pretrained_seed}_ksvd_layers_{args.ksvd_layers}"
                   f"_lambda_mean_{args.lambda_mean}_var_{args.lambda_var}_ce_{args.lambda_ce}"
                   f"_batchsize_{args.batch_size}_epochs_{args.nb_epochs}"
               ),
               config=vars(args))

    # The device mask only takes effect if it is set before CUDA is initialised,
    # so export it ahead of seeding and model creation.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Set seed everything
    set_seed(args.seed)

    logger = utils.utils.get_logger(save_path)
    logger.info(json.dumps(vars(args), indent=4, sort_keys=True))

    train_loader, val_loader, _, nb_cls = datasets.cifar_loader.get_loader(
        args.dataset, args.train_dir, args.val_dir, args.test_dir, args.batch_size
    )

    for run in range(args.nb_run):
        prefix = f'{run + 1} / {args.nb_run} Running'
        logger.info(100*'#' + '\n' + prefix)

        ## define student model
        net = models.get_model.get_model(student_name, nb_cls, logger, args)
        print(net)
        print(sum(p.numel() for p in net.parameters() if p.requires_grad))
        net.cuda()

        ## load the pretrained model
        pretrained_ViT = models.get_model.get_model(teacher_name, nb_cls, logger, args)
        pretrained_ViT.load_state_dict(torch.load(os.path.join(pretrained_path, f'best_acc_net_{run + 1}.pth')))
        pretrained_ViT.cuda()

        if args.model == 'diffusion_distillation':
            # The diffusion student starts from the pretrained ViT's embedding,
            # positional embedding, last-block layer norm / MLP and classifier.
            last_block = pretrained_ViT.enc[args.depth - 1]
            net.emb.load_state_dict(pretrained_ViT.emb.state_dict())
            net.pos_emb.data.copy_(pretrained_ViT.pos_emb.data)
            net.ln.load_state_dict(last_block.la2.state_dict())
            net.solution_head_1.load_state_dict(last_block.mlp.state_dict())
            net.solution_head_2.load_state_dict(pretrained_ViT.fc.state_dict())

        ## define optimizer with warm-up
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay
        )
        base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.nb_epochs, eta_min=args.min_lr
        )
        scheduler = warmup_scheduler.GradualWarmupScheduler(
            optimizer,
            multiplier=1.,
            total_epoch=args.warmup_epoch,
            after_scheduler=base_scheduler
        )

        ## best validation metrics seen so far in this repetition
        best_acc, best_auroc, best_aurc = 0, 0, 1e6

        ## start training
        for epoch in range(args.nb_epochs):
            train.train_distillation(train_loader, net, optimizer, epoch, logger, args, pretrained_ViT)

            scheduler.step()

            # validation
            res = val.validation_diffusion(val_loader, net, args, pretrained_ViT)
            log = [f"{key}: {res[key]:.3f}" for key in res]
            msg = '################## \n ---> Validation Epoch {:d}\t'.format(epoch) + '\t'.join(log)
            logger.info(msg)

            # wandb drops entries whose step is lower than one already logged,
            # so steps keep increasing across repetitions instead of restarting at 0.
            global_step = run * args.nb_epochs + epoch
            wandb.log({f"Val/{key}": res[key] for key in res}, step=global_step)

            if res['Acc.'] > best_acc:
                acc = res['Acc.']
                logger.info(f'Accuracy improved from {best_acc:.2f} to {acc:.2f}!!!')
                best_acc = acc
                torch.save(
                    net.state_dict(),
                    os.path.join(
                        save_path,
                        f'best_acc_net_{run+1}_{args.temperature}_{args.lambda_mean}_{args.lambda_var}_{args.lambda_ce}.pth'
                    )
                )

            # AUROC / AURC improvements are logged only; no checkpoint is written for them.
            if res['AUROC'] > best_auroc:
                auroc = res['AUROC']
                logger.info(f'AUROC improved from {best_auroc:.2f} to {auroc:.2f}!!!')
                best_auroc = auroc

            if res['AURC'] < best_aurc:
                aurc = res['AURC']
                logger.info(f'AURC decreased from {best_aurc:.2f} to {aurc:.2f}!!!')
                best_aurc = aurc

        # Always keep the final weights of the repetition as well.
        torch.save(net.state_dict(), os.path.join(save_path, f'last_net_{run+1}.pth'))
        
if __name__ == '__main__':
    # Script entry point: parse the arguments, run the training routine that
    # matches the requested model, evaluate it, then close the wandb run.
    args = utils.train_utils.get_args_parser()
    model_name = args.model

    if model_name == 'diffusion':
        main_diffusion(args)
        evaluate = test.test_diffusion
    elif model_name == 'svdkl':
        main_svdkl(args)
        evaluate = test.test
    elif model_name in ('diffusion_distillation', 'vit_cifar_distillation'):
        main_distillation(args)
        evaluate = test.test_distillation
    else:
        # Every other model name goes through the plain training routine.
        main(args)
        evaluate = test.test

    evaluate(args)
    wandb.finish()
